Tril

返回输入张量的下三角部分,其余部分设置为零

对于输入形状为 \((*, H, W)\) 的张量,返回每个 \(H \times W\) 矩阵的下三角部分。

\[\begin{split}\text{output}[i,j] = \begin{cases} \text{input}[i,j], & \text{if } j \leq i + k \\ 0, & \text{otherwise} \end{cases}\end{split}\]
其中 \(k\) 为对角线偏移量:
  • \(k = 0\):主对角线(默认)

  • \(k > 0\):主对角线上方第 k 条对角线

  • \(k < 0\):主对角线下方第 k 条对角线

输入:
  • src - 输入数据地址。

  • k - 对角线偏移量,默认为 0。

  • height - 矩阵高度。

  • width - 矩阵宽度。

  • out_elems - 输出矩阵的数量(批量大小)。

  • core_mask(int, 可选) - 核掩码(仅适用于共享存储版本)。

输出:
  • dst - 计算结果地址。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持int8, int16, int32, fp32, fp64, cplx64, cplx128

  • MT7004 支持fp16, fp32, int16, int32, cplx64

共享存储版本:

void i8_tril_s(int8_t *dst, int8_t *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void i16_tril_s(int16_t *dst, int16_t *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void i32_tril_s(int32_t *dst, int32_t *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void hp_tril_s(half *dst, half *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void fp_tril_s(float *dst, float *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void dp_tril_s(double *dst, double *src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void c64_tril_s(float (*dst)[2], float (*src)[2], int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void c128_tril_s(double (*dst)[2], double (*src)[2], int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <tril.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *input = (float *)0xA0000000;   //input在DDR空间
 7    float *output = (float *)0xB0000000;  //output在DDR空间
 8    int64_t k = 0;           // 主对角线
 9    int64_t height = 4;      // 矩阵高度
10    int64_t width = 4;       // 矩阵宽度
11    int64_t out_elems = 1;   // 矩阵数量
12    int core_mask = 0xff;
13    fp_tril_s(output, input, core_mask, k, height, width, out_elems);
14    return 0;
15}

私有存储版本:

void i8_tril_p(int8_t *dst, int8_t *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void i16_tril_p(int16_t *dst, int16_t *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void i32_tril_p(int32_t *dst, int32_t *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void hp_tril_p(half *dst, half *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void fp_tril_p(float *dst, float *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void dp_tril_p(double *dst, double *src, int64_t k, int64_t height, int64_t width, int64_t out_elems)
void c64_tril_p(float (*dst)[2], float (*src)[2], int64_t k, int64_t height, int64_t width, int64_t out_elems)
void c128_tril_p(double (*dst)[2], double (*src)[2], int64_t k, int64_t height, int64_t width, int64_t out_elems)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <tril.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *input = (float *)0x10810000;   //input在L2空间
 7    float *output = (float *)0x10850000;  //output在L2空间
 8    int64_t k = 0;           // 主对角线
 9    int64_t height = 4;      // 矩阵高度
10    int64_t width = 4;       // 矩阵宽度
11    int64_t out_elems = 1;   // 矩阵数量
12    fp_tril_p(output, input, k, height, width, out_elems);
13    return 0;
14}

备注

示例说明:

对于 4x4 矩阵,k=0 时的下三角输出示例:

输入矩阵:              输出矩阵:
1  2  3  4             1  0  0  0
5  6  7  8      =>     5  6  0  0
9  10 11 12            9  10 11 0
13 14 15 16            13 14 15 16

k=1 时(上移一条对角线,包含更多元素):

输入矩阵:              输出矩阵:
1  2  3  4             1  2  0  0
5  6  7  8      =>     5  6  7  0
9  10 11 12            9  10 11 12
13 14 15 16            13 14 15 16

k=-1 时(下移一条对角线,更严格的下三角):

输入矩阵:              输出矩阵:
1  2  3  4             0  0  0  0
5  6  7  8      =>     5  0  0  0
9  10 11 12            9  10 0  0
13 14 15 16            13 14 15 0